import argparse
import torch
import torchvision.transforms as T

from PIL import Image
from model.model import Yolov1
from model.utils import cellboxes_to_boxes, non_max_suppression, plot_image, CLASSES

LAST_MODEL_FILE = "last.pth.tar"  # checkpoint file loaded by main()
IMG_WIDTH = 448  # network input width in pixels
IMG_HEIGHT = 448  # network input height in pixels

# Preprocessing applied to the source image before inference: resize to the
# network input resolution, then convert to a float tensor.
# T.Resize expects its size as (height, width), so HEIGHT must come first.
transform = T.Compose(
    [
        T.Resize((IMG_HEIGHT, IMG_WIDTH)),
        T.ToTensor(),
    ]
)


def main(args):
    """Run YOLOv1 inference on a single image and plot the detections.

    Args:
        args: Parsed CLI namespace providing ``img`` (path to the input
            image) and ``device`` (torch device string such as ``"cuda:0"``
            or ``"cpu"``).
    """
    # Load the image eagerly and release the file handle. Converting to RGB
    # guarantees 3 channels; grayscale, palette or RGBA inputs would
    # otherwise produce a tensor with the wrong channel count for the model.
    with Image.open(args.img) as img_file:
        src_image = img_file.convert("RGB")

    # Factors mapping coordinates in the IMG_WIDTH x IMG_HEIGHT network input
    # back to the source image resolution.
    ratio_width = src_image.width / float(IMG_WIDTH)
    ratio_height = src_image.height / float(IMG_HEIGHT)
    print(f"width ratio:{ratio_width}, height ratio: {ratio_height}")

    # Load the model. map_location remaps the checkpoint's tensors onto the
    # requested device, so a checkpoint saved on GPU also loads on CPU.
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(args.device)
    checkpoint = torch.load(LAST_MODEL_FILE, map_location=args.device)
    model.load_state_dict(checkpoint["state_dict"])
    print(f"Loaded model: {LAST_MODEL_FILE}")

    # Resize + tensorize, then add a leading batch dimension of size 1.
    image = transform(src_image).unsqueeze(0).to(args.device)

    # Inference only: eval mode and no autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        predictions = model(image)

    # Decode grid predictions into boxes and suppress overlapping ones.
    bboxes = cellboxes_to_boxes(predictions)
    nms_boxes = non_max_suppression(
        bboxes[0],
        iou_threshold=0.5,
        threshold=0.4,
        box_format="midpoint",
    )

    # Rescale the trailing four values of each box (x, y, w, h in midpoint
    # format) from network-input space to the source image.
    # NOTE(review): assumes cellboxes_to_boxes yields coordinates in
    # IMG_WIDTH x IMG_HEIGHT pixel space and that plot_image does not rescale
    # them itself — confirm against model.utils.
    for box in nms_boxes:
        box[-4] = box[-4] * ratio_width
        box[-3] = box[-3] * ratio_height
        box[-2] = box[-2] * ratio_width
        box[-1] = box[-1] * ratio_height
    plot_image(src_image, nms_boxes, CLASSES)


def parse_args(argv=None):
    """Parse command-line options for single-image YOLOv1 inference.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``img`` (path to the input image) and
        ``device`` (torch device string).
    """
    parser = argparse.ArgumentParser(description="Yolo-v1 inference")

    # Fall back to CPU when CUDA is unavailable so the script still runs on
    # GPU-less machines; an explicit --device always takes precedence.
    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # fmt: off
    parser.add_argument("img", help="Path to a single image file")
    parser.add_argument('--device', default=default_device, help='Device used for inference')
    # fmt: on

    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main().
    main(parse_args())
